import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mmcv.cnn import ConvModule, Scale
from mmcv.runner import BaseModule, force_fp32
from mmcv.image import tensor2imgs
from skimage import color



# Large finite stand-in for "infinity" (1e8). Not referenced within this
# section of the file; presumably consumed elsewhere — verify before changing.
INF = 1.0e8


def compute_pairwise_term(mask_logits, pairwise_size, pairwise_dilation):
    """
    Pairwise (neighbour-agreement) loss term computed from mask logits.

    With ``p = sigmoid(mask_logits)``, for every pixel i and each of its
    k^2 - 1 dilated neighbours j (as gathered by ``unfold_wo_center``) this
    returns ``-log(p_i * p_j + (1 - p_i) * (1 - p_j))``, i.e. the negative
    log-probability that i and j receive the same label. Everything is kept
    in log space for numerical stability.

    :param mask_logits: 4-D tensor [N, C, H, W] of logits
    :param pairwise_size: neighbourhood kernel size k (unfold_wo_center requires it to be odd)
    :param pairwise_dilation: dilation of the neighbourhood
    :return: tensor [N, k^2 - 1, H, W] taken from channel 0
    """
    assert mask_logits.dim() == 4

    def _neighbours(t):
        return unfold_wo_center(
            t, kernel_size=pairwise_size, dilation=pairwise_dilation
        )

    log_p = F.logsigmoid(mask_logits)    # log(p)
    log_q = F.logsigmoid(-mask_logits)   # log(1 - p)

    # log(p_i * p_j) and log((1 - p_i) * (1 - p_j)) for every neighbour j.
    both_fg = log_p.unsqueeze(2) + _neighbours(log_p)
    both_bg = log_q.unsqueeze(2) + _neighbours(log_q)

    # Stable log(exp(a) + exp(b)): shift by the elementwise max so that the
    # exponentials cannot overflow.
    shift = torch.max(both_fg, both_bg)
    log_agree = shift + torch.log(
        torch.exp(both_fg - shift) + torch.exp(both_bg - shift)
    )

    return -log_agree[:, 0]


def dice_coefficient(x, target):
    """
    Per-instance dice loss: 1 - 2 * sum(x * target) / (sum(x^2) + sum(target^2) + eps).

    Despite its name, this returns the dice *loss* (1 - coefficient), one
    value per instance along dim 0.

    :param x: tensor whose dim 0 indexes instances; all other dims are flattened
    :param target: tensor with the same number of instances / elements as ``x``
    :return: 1-D tensor of shape [x.size(0)]; empty when there are no instances
    """
    eps = 1e-5
    n_inst = x.size(0)
    if n_inst == 0:
        # ``reshape(0, -1)`` raises on empty tensors (the -1 is ambiguous);
        # an explicit (0, 0) shape makes an empty batch yield an empty loss.
        x = x.reshape(0, 0)
        target = target.reshape(0, 0)
    else:
        x = x.reshape(n_inst, -1)
        target = target.reshape(n_inst, -1)
    intersection = (x * target).sum(dim=1)
    # eps keeps the ratio finite when both inputs are all-zero.
    union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps
    loss = 1. - (2 * intersection / union)
    return loss

def compute_project_term(mask_scores, gt_bitmasks):
    """
    Projection loss: compare max-projections of predicted and target masks.

    Both inputs are max-reduced (keepdim) along dim 2 and, separately, along
    dim 3. Each pair of projections is scored with ``dice_coefficient``
    (a per-instance dice loss), and the two per-instance losses are summed
    and averaged over instances.

    :param mask_scores: predicted mask scores; presumably [N, 1, H, W] — confirm with caller
    :param gt_bitmasks: target masks with the same layout as ``mask_scores``
    :return: scalar tensor
    """
    def _project(t, axis):
        return t.max(dim=axis, keepdim=True)[0]

    loss_along_2 = dice_coefficient(
        _project(mask_scores, 2), _project(gt_bitmasks, 2)
    )
    loss_along_3 = dice_coefficient(
        _project(mask_scores, 3), _project(gt_bitmasks, 3)
    )
    return (loss_along_3 + loss_along_2).mean()


def aligned_bilinear(tensor, factor):
    """
    Upsample a 4-D tensor by an integer factor with corner-aligned bilinear
    interpolation, then shift so the output grid stays aligned with the input.

    :param tensor: [N, C, H, W]
    :param factor: integer upsampling factor >= 1 (1 returns ``tensor`` unchanged)
    :return: [N, C, factor * H, factor * W]
    """
    assert tensor.dim() == 4
    assert factor >= 1
    assert isinstance(factor, int)

    if factor == 1:
        return tensor

    in_h, in_w = tensor.shape[-2:]
    out_h, out_w = factor * in_h + 1, factor * in_w + 1

    # Replicate one extra row/column so the last input pixel has a right/bottom
    # neighbour to interpolate towards.
    padded = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    upsampled = F.interpolate(
        padded, size=(out_h, out_w), mode='bilinear', align_corners=True
    )

    # Shift by half a factor (top/left) before cropping back to factor * size.
    shift = factor // 2
    shifted = F.pad(upsampled, pad=(shift, 0, shift, 0), mode="replicate")
    return shifted[:, :, :out_h - 1, :out_w - 1]

def unfold_wo_center(x, kernel_size, dilation):
    """
    Gather each pixel's k x k dilated neighbourhood, excluding the pixel itself.

    The input is zero-padded so that the spatial size is preserved ("same"
    padding); neighbours falling outside the image therefore read as 0.

    :param x: [N, C, H, W]
    :param kernel_size: k (must be odd)
    :param dilation: spacing between sampled neighbours
    :return: [N, C, k^2 - 1, H, W]
    """
    assert x.dim() == 4
    assert kernel_size % 2 == 1

    n, c, h, w = x.shape

    # Footprint of a dilated kernel; half of it on each side keeps H and W.
    effective_kernel = kernel_size + (dilation - 1) * (kernel_size - 1)
    columns = F.unfold(
        x, kernel_size=kernel_size,
        padding=effective_kernel // 2,
        dilation=dilation
    )
    neighbourhoods = columns.reshape(n, c, -1, h, w)

    # Drop the middle entry of the flattened k*k window (the pixel itself).
    center = kernel_size * kernel_size // 2
    before = neighbourhoods[:, :, :center]
    after = neighbourhoods[:, :, center + 1:]
    return torch.cat((before, after), dim=2)


def get_image_color_similarity(image, mask, pairwise_size, pairwise_dilation):
    """
    Color affinity between every pixel and its dilated neighbours, gated by a mask.

    For pixel i and neighbour j: exp(-0.5 * ||image_i - image_j||_2), with the
    L2 norm taken over the channel dimension, multiplied by ``mask`` at j.

    :param image: [1, C, H, W]; a single image (Lab channels at the visible call site)
    :param mask: [H, W] weights, applied at each neighbour position
    :param pairwise_size: neighbourhood kernel size k (odd)
    :param pairwise_dilation: neighbourhood dilation d
    :return: [1, k^2 - 1, H, W]
    """
    assert image.dim() == 4
    assert image.size(0) == 1

    neighbours = unfold_wo_center(
        image, kernel_size=pairwise_size, dilation=pairwise_dilation
    )
    color_dist = torch.norm(image.unsqueeze(2) - neighbours, dim=1)
    similarity = torch.exp(-color_dist * 0.5)

    mask_4d = mask[None, None]
    neighbour_weight = unfold_wo_center(
        mask_4d, kernel_size=pairwise_size, dilation=pairwise_dilation
    )[:, 0]

    return similarity * neighbour_weight

def get_bitmasks_from_boxes(gt_bboxes, padded_images, padded_image_masks, *,
                            out_stride=4, pairwise_size=3, pairwise_dilation=2):
    """
    Build box-shaped target masks and color-similarity maps for each image.

    The original version read ``self.out_stride`` / ``self.pairwise_size`` /
    ``self.pairwise_dilation`` even though it is a module-level function
    (always a NameError); those values are now keyword-only parameters.
    NOTE(review): the defaults (4, 3, 2) mirror the usual BoxInst setting —
    confirm against the caller's config.

    :param gt_bboxes: per-image iterable of boxes, each indexed as [x1, y1, x2, y2]
    :param padded_images: [N, C, H, W]; H and W must be divisible by ``out_stride``.
        Each strided image is cast to uint8 and passed to ``skimage.color.rgb2lab``,
        so values are presumably RGB in [0, 255] — TODO confirm channel order/range.
    :param padded_image_masks: [N, H, W] masks, sampled at the same stride
    :param out_stride: spatial stride of the low-resolution outputs
    :param pairwise_size: neighbourhood size forwarded to get_image_color_similarity
    :param pairwise_dilation: neighbourhood dilation forwarded to get_image_color_similarity
    :return: (similarities, bitmasks, bitmasks_full), three lists with one entry
        per image; for an image with n boxes:
        similarities[i] is [n, k^2 - 1, H/s, W/s] (one copy of the map per box),
        bitmasks[i] is [n, H/s, W/s] and bitmasks_full[i] is [n, H, W].
        Images without boxes yield zero-length tensors instead of crashing.
    """
    h, w = padded_images.shape[2:]
    stride = out_stride
    start = int(stride // 2)

    assert padded_images.size(2) % stride == 0
    assert padded_images.size(3) % stride == 0

    downsampled_images = F.avg_pool2d(
        padded_images.float(), kernel_size=stride, stride=stride, padding=0
    )
    downsampled_image_masks = padded_image_masks[:, start::stride, start::stride]

    similarities = []
    bitmasks = []
    bitmasks_full = []

    for i, per_img_gt_bboxes in enumerate(gt_bboxes):
        image_lab = color.rgb2lab(
            downsampled_images[i].byte().permute(1, 2, 0).cpu().numpy()
        )
        image_lab = torch.as_tensor(
            image_lab, device=padded_image_masks.device, dtype=torch.float32
        )
        image_lab = image_lab.permute(2, 0, 1)[None]
        image_color_similarity = get_image_color_similarity(
            image_lab, downsampled_image_masks[i],
            pairwise_size, pairwise_dilation
        )

        per_im_bitmasks = []
        per_im_bitmasks_full = []

        for per_box in per_img_gt_bboxes:  # [x1, y1, x2, y2]
            x1, y1, x2, y2 = (int(v) for v in per_box[:4])
            bitmask_full = torch.zeros((h, w), device=per_box.device).float()
            # Clamp slice bounds at 0: a negative Python slice bound would wrap
            # around to the opposite edge instead of stopping at the border.
            bitmask_full[max(y1, 0):max(y2 + 1, 0),
                         max(x1, 0):max(x2 + 1, 0)] = 1.0

            bitmask = bitmask_full[start::stride, start::stride]

            assert bitmask.size(0) * stride == h
            assert bitmask.size(1) * stride == w

            per_im_bitmasks.append(bitmask)
            per_im_bitmasks_full.append(bitmask_full)

        num_boxes = len(per_im_bitmasks)
        if num_boxes > 0:
            per_im_bitmasks = torch.stack(per_im_bitmasks, dim=0)
            per_im_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0)
        else:
            # torch.stack rejects an empty list; emit zero-length tensors instead.
            device = padded_image_masks.device
            per_im_bitmasks = torch.zeros((0, h // stride, w // stride), device=device)
            per_im_bitmasks_full = torch.zeros((0, h, w), device=device)

        # One copy of the image's similarity map per box (zero copies if none).
        similarities.append(image_color_similarity.repeat(num_boxes, 1, 1, 1))
        bitmasks.append(per_im_bitmasks)
        bitmasks_full.append(per_im_bitmasks_full)

    return similarities, bitmasks, bitmasks_full